import os
from dataclasses import dataclass
from typing import final
import configparser

from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP
import pipmaster as pm

if not pm.is_installed("neo4j"):
    pm.install("neo4j")

from neo4j import (
    AsyncGraphDatabase,
    AsyncManagedTransaction,
)

from dotenv import load_dotenv

# Load environment variables from the .env file in the current working directory.
# override=False keeps any value that is already set in the process environment.
load_dotenv(dotenv_path=".env", override=False)

# Default for the `max_nodes` parameter of get_knowledge_graph(); can be
# overridden through the MAX_GRAPH_NODES environment variable.
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))

# Settings read from config.ini serve as fallbacks in initialize() when the
# corresponding MEMGRAPH_* environment variables are not set.
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")


@final
@dataclass
class MemgraphStorage(BaseGraphStorage):
    def __init__(self, namespace, global_config, embedding_func, workspace=None):
        """Set up the storage without opening a connection.

        A non-blank ``MEMGRAPH_WORKSPACE`` environment variable takes precedence
        over the ``workspace`` argument; when neither yields a value the
        workspace is stored as an empty string. The driver is created later by
        ``initialize()``.

        Args:
            namespace: Forwarded to the base storage initializer.
            global_config: Forwarded to the base storage initializer.
            embedding_func: Forwarded to the base storage initializer.
            workspace: Workspace name used when the environment does not
                provide one.
        """
        env_workspace = os.environ.get("MEMGRAPH_WORKSPACE")
        env_overrides = bool(env_workspace and env_workspace.strip())
        chosen_workspace = env_workspace if env_overrides else workspace
        super().__init__(
            namespace=namespace,
            workspace=chosen_workspace or "",
            global_config=global_config,
            embedding_func=embedding_func,
        )
        self._driver = None

    def _get_workspace_label(self) -> str:
        """Get workspace label, return 'base' for compatibility when workspace is empty"""
        workspace = getattr(self, "workspace", None)
        return workspace if workspace else "base"

    async def initialize(self):
        """Create the driver, ensure the entity_id index exists and verify connectivity.

        Connection settings are resolved in this order: the ``MEMGRAPH_URI``,
        ``MEMGRAPH_USERNAME``, ``MEMGRAPH_PASSWORD`` and ``MEMGRAPH_DATABASE``
        environment variables, then the ``[memgraph]`` section of config.ini,
        then built-in defaults.

        Raises:
            Exception: If the connectivity check fails. In that case the driver
                is closed and ``self._driver`` is reset to ``None`` so no
                connection resources are leaked.
        """
        URI = os.environ.get(
            "MEMGRAPH_URI",
            config.get("memgraph", "uri", fallback="bolt://localhost:7687"),
        )
        USERNAME = os.environ.get(
            "MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="")
        )
        PASSWORD = os.environ.get(
            "MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")
        )
        DATABASE = os.environ.get(
            "MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph")
        )

        driver = AsyncGraphDatabase.driver(
            URI,
            auth=(USERNAME, PASSWORD),
        )
        self._driver = driver
        self._DATABASE = DATABASE
        # Resolved up front so the warning below can always reference it.
        workspace_label = self._get_workspace_label()
        try:
            async with driver.session(database=DATABASE) as session:
                # Create index for workspace nodes on entity_id if it doesn't exist
                try:
                    await session.run(
                        f"""CREATE INDEX ON :{workspace_label}(entity_id)"""
                    )
                    logger.info(
                        f"Created index on :{workspace_label}(entity_id) in Memgraph."
                    )
                except Exception as e:
                    # Index may already exist, which is not an error
                    logger.warning(
                        f"Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}"
                    )
                # Cheap round trip that proves the server is reachable.
                await session.run("RETURN 1")
                logger.info(f"Connected to Memgraph at {URI}")
        except Exception as e:
            logger.error(f"Failed to connect to Memgraph at {URI}: {e}")
            # Do not keep (or leak) a driver that could not connect.
            try:
                await driver.close()
            finally:
                self._driver = None
            raise

    async def finalize(self):
        if self._driver is not None:
            await self._driver.close()
            self._driver = None

    async def __aexit__(self, exc_type, exc, tb):
        """Async context-manager exit: release the driver via ``finalize()``.

        Returns ``None``, so an exception raised inside the ``async with`` body
        is not suppressed.
        """
        await self.finalize()

    async def index_done_callback(self):
        """No-op: this backend performs no work in this callback."""
        # Memgraph handles persistence automatically
        pass

    async def has_node(self, node_id: str) -> bool:
        """
        Check if a node exists in the graph.

        Args:
            node_id: The entity_id of the node to check.

        Returns:
            bool: True if the node exists in the current workspace, False otherwise.

        Raises:
            RuntimeError: If the driver has not been initialized.
            Exception: If there is an error checking the node existence.
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            # Stays None until session.run() succeeds, so cleanup never touches
            # an unbound name and hides the original error.
            result = None
            try:
                workspace_label = self._get_workspace_label()
                query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
                result = await session.run(query, entity_id=node_id)
                single_result = await result.single()
                return (
                    single_result["node_exists"] if single_result is not None else False
                )
            except Exception as e:
                logger.error(f"Error checking node existence for {node_id}: {str(e)}")
                raise
            finally:
                if result is not None:
                    await result.consume()  # Ensure result is fully consumed

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """
        Check if an edge exists between two nodes in the graph.

        The match ignores relationship direction and type.

        Args:
            source_node_id: The entity_id of the source node.
            target_node_id: The entity_id of the target node.

        Returns:
            bool: True if the edge exists, False otherwise.

        Raises:
            RuntimeError: If the driver has not been initialized.
            Exception: If there is an error checking the edge existence.
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            # Stays None until session.run() succeeds, so cleanup never touches
            # an unbound name and hides the original error.
            result = None
            try:
                workspace_label = self._get_workspace_label()
                query = (
                    f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
                    "RETURN COUNT(r) > 0 AS edgeExists"
                )
                result = await session.run(
                    query,
                    source_entity_id=source_node_id,
                    target_entity_id=target_node_id,
                )  # type: ignore
                single_result = await result.single()
                return (
                    single_result["edgeExists"] if single_result is not None else False
                )
            except Exception as e:
                logger.error(
                    f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
                )
                raise
            finally:
                if result is not None:
                    await result.consume()  # Ensure result is fully consumed

    async def get_node(self, node_id: str) -> dict[str, str] | None:
        """Look up a node by entity_id and return its properties.

        Args:
            node_id: The entity_id of the node to fetch

        Returns:
            dict: The node's properties when a match exists (the first match is
                used and a warning is logged if there are several)
            None: When no node matches

        Raises:
            RuntimeError: If the driver has not been initialized
            Exception: If there is an error executing the query
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            try:
                label = self._get_workspace_label()
                cypher = f"MATCH (n:`{label}` {{entity_id: $entity_id}}) RETURN n"
                cursor = await session.run(cypher, entity_id=node_id)
                try:
                    # Two rows are enough to detect duplicates.
                    matches = await cursor.fetch(2)
                    if not matches:
                        return None
                    if len(matches) > 1:
                        logger.warning(
                            f"Multiple nodes found with label '{node_id}'. Using first node."
                        )
                    props = dict(matches[0]["n"])
                    # Hide the workspace label from a "labels" property, if present.
                    if "labels" in props:
                        props["labels"] = [
                            name for name in props["labels"] if name != label
                        ]
                    return props
                finally:
                    await cursor.consume()  # Always drain the result
            except Exception as e:
                logger.error(f"Error getting node for {node_id}: {str(e)}")
                raise

    async def node_degree(self, node_id: str) -> int:
        """Count the relationships attached to the node with the given entity_id.

        Relationships are counted regardless of direction.

        Args:
            node_id: The entity_id of the node

        Returns:
            int: The ``degree`` value returned by the query, or 0 when the
                query yields no record

        Raises:
            RuntimeError: If the driver has not been initialized
            Exception: If there is an error executing the query
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            try:
                label = self._get_workspace_label()
                cypher = f"""
                    MATCH (n:`{label}` {{entity_id: $entity_id}})
                    OPTIONAL MATCH (n)-[r]-()
                    RETURN COUNT(r) AS degree
                """
                cursor = await session.run(cypher, entity_id=node_id)
                try:
                    row = await cursor.single()
                    if row:
                        return row["degree"]
                    logger.warning(f"No node found with label '{node_id}'")
                    return 0
                finally:
                    await cursor.consume()  # Always drain the result
            except Exception as e:
                logger.error(f"Error getting node degree for {node_id}: {str(e)}")
                raise

    async def get_all_labels(self) -> list[str]:
        """
        Get the distinct entity_id values of all nodes in the current workspace.

        Returns:
            list[str]: Entity IDs, ordered by the query's ``ORDER BY`` clause.

        Raises:
            RuntimeError: If the driver has not been initialized.
            Exception: If there is an error executing the query
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            # Stays None until session.run() succeeds, so cleanup never touches
            # an unbound name and hides the original error.
            result = None
            try:
                workspace_label = self._get_workspace_label()
                query = f"""
                MATCH (n:`{workspace_label}`)
                WHERE n.entity_id IS NOT NULL
                RETURN DISTINCT n.entity_id AS label
                ORDER BY label
                """
                result = await session.run(query)
                return [record["label"] async for record in result]
            except Exception as e:
                logger.error(f"Error getting all labels: {str(e)}")
                raise
            finally:
                if result is not None:
                    await result.consume()  # Ensure result is fully consumed

    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
        """Retrieve all edges (relationships) touching the node with the given entity_id.

        Args:
            source_node_id: entity_id of the node to get edges for

        Returns:
            list[tuple[str, str]]: (source_entity_id, connected_entity_id) tuples,
                one per matched relationship; an empty list when the node has no
                edges or does not exist

        Raises:
            RuntimeError: If the driver has not been initialized
            Exception: If there is an error executing the query
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )
        try:
            async with self._driver.session(
                database=self._DATABASE, default_access_mode="READ"
            ) as session:
                # Stays None until session.run() succeeds, so cleanup never
                # touches an unbound name and hides the original error.
                results = None
                try:
                    workspace_label = self._get_workspace_label()
                    query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
                            OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`)
                            WHERE connected.entity_id IS NOT NULL
                            RETURN n, r, connected"""
                    results = await session.run(query, entity_id=source_node_id)

                    edges = []
                    async for record in results:
                        source_node = record["n"]
                        connected_node = record["connected"]

                        # OPTIONAL MATCH yields a null `connected` when there are no edges
                        if not source_node or not connected_node:
                            continue

                        source_label = source_node.get("entity_id")
                        target_label = connected_node.get("entity_id")

                        if source_label and target_label:
                            edges.append((source_label, target_label))

                    return edges
                except Exception as e:
                    logger.error(
                        f"Error getting edges for node {source_node_id}: {str(e)}"
                    )
                    raise
                finally:
                    if results is not None:
                        await results.consume()  # Ensure results are consumed
        except Exception as e:
            logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
            raise

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> dict[str, str] | None:
        """Get edge properties between two nodes.

        The match ignores relationship direction. When several relationships
        match, the first one returned is used.

        Args:
            source_node_id: entity_id of the source node
            target_node_id: entity_id of the target node

        Returns:
            dict: Edge properties if found; missing ``weight``, ``source_id``,
                ``description`` and ``keywords`` keys are filled with defaults
                (and a warning is logged for each)
            None: If no edge is found

        Raises:
            RuntimeError: If the driver has not been initialized
            Exception: If there is an error executing the query
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            # Stays None until session.run() succeeds, so cleanup never touches
            # an unbound name and hides the original error.
            result = None
            try:
                workspace_label = self._get_workspace_label()
                query = f"""
                MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}})
                RETURN properties(r) as edge_properties
                """
                result = await session.run(
                    query,
                    source_entity_id=source_node_id,
                    target_entity_id=target_node_id,
                )
                records = await result.fetch(2)
                if not records:
                    return None

                edge_result = dict(records[0]["edge_properties"])
                required_defaults = {
                    "weight": 0.0,
                    "source_id": None,
                    "description": None,
                    "keywords": None,
                }
                for key, default_value in required_defaults.items():
                    if key not in edge_result:
                        edge_result[key] = default_value
                        logger.warning(
                            f"Edge between {source_node_id} and {target_node_id} is missing property: {key}. Using default value: {default_value}"
                        )
                return edge_result
            except Exception as e:
                logger.error(
                    f"Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
                )
                raise
            finally:
                if result is not None:
                    await result.consume()  # Ensure result is fully consumed

    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
        """
        Create or update a node in the Memgraph database.

        The node is matched on the workspace label plus ``entity_id``; the
        contents of ``node_data`` are merged into its properties and
        ``node_data["entity_type"]`` is added as an additional label.

        Args:
            node_id: The entity_id used to match the node
            node_data: Dictionary of node properties; must contain
                'entity_type' and 'entity_id'

        Raises:
            RuntimeError: If the driver has not been initialized
            KeyError: If node_data has no 'entity_type' key
            ValueError: If node_data has no 'entity_id' key
            Exception: If the write transaction fails
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )
        type_label = node_data["entity_type"]
        if "entity_id" not in node_data:
            raise ValueError("Neo4j: node properties must contain an 'entity_id' field")

        try:
            async with self._driver.session(database=self._DATABASE) as session:
                label = self._get_workspace_label()
                cypher = f"""
                    MERGE (n:`{label}` {{entity_id: $entity_id}})
                    SET n += $properties
                    SET n:`{type_label}`
                    """

                async def _merge_node(tx: AsyncManagedTransaction):
                    outcome = await tx.run(
                        cypher, entity_id=node_id, properties=node_data
                    )
                    await outcome.consume()  # Drain the result inside the transaction

                await session.execute_write(_merge_node)
        except Exception as e:
            logger.error(f"Error during upsert: {str(e)}")
            raise

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ) -> None:
        """
        Create or update the relationship between two existing nodes.

        Both endpoints are matched by entity_id within the current workspace;
        a ``DIRECTED`` relationship between them is merged and ``edge_data`` is
        merged into its properties. Nothing is written when either endpoint
        does not match.

        Args:
            source_node_id (str): entity_id of the source node
            target_node_id (str): entity_id of the target node
            edge_data (dict): Dictionary of properties to set on the edge

        Raises:
            RuntimeError: If the driver has not been initialized
            Exception: If there is an error executing the query
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )

        async def _merge_edge(tx: AsyncManagedTransaction):
            label = self._get_workspace_label()
            cypher = f"""
                    MATCH (source:`{label}` {{entity_id: $source_entity_id}})
                    WITH source
                    MATCH (target:`{label}` {{entity_id: $target_entity_id}})
                    MERGE (source)-[r:DIRECTED]-(target)
                    SET r += $properties
                    RETURN r, source, target
                    """
            outcome = await tx.run(
                cypher,
                source_entity_id=source_node_id,
                target_entity_id=target_node_id,
                properties=edge_data,
            )
            try:
                await outcome.fetch(2)
            finally:
                await outcome.consume()  # Drain the result inside the transaction

        try:
            async with self._driver.session(database=self._DATABASE) as session:
                await session.execute_write(_merge_edge)
        except Exception as e:
            logger.error(f"Error during edge upsert: {str(e)}")
            raise

    async def delete_node(self, node_id: str) -> None:
        """Detach-delete every node in the workspace that has the given entity_id.

        Args:
            node_id: The entity_id of the node to delete

        Raises:
            RuntimeError: If the driver has not been initialized
            Exception: If there is an error executing the query
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )

        async def _detach_delete(tx: AsyncManagedTransaction):
            label = self._get_workspace_label()
            cypher = f"""
            MATCH (n:`{label}` {{entity_id: $entity_id}})
            DETACH DELETE n
            """
            outcome = await tx.run(cypher, entity_id=node_id)
            logger.debug(f"Deleted node with label {node_id}")
            await outcome.consume()

        try:
            async with self._driver.session(database=self._DATABASE) as session:
                await session.execute_write(_detach_delete)
        except Exception as e:
            logger.error(f"Error during node deletion: {str(e)}")
            raise

    async def remove_nodes(self, nodes: list[str]):
        """Delete multiple nodes

        Args:
            nodes: List of node labels to be deleted
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )
        for node in nodes:
            await self.delete_node(node)

    async def remove_edges(self, edges: list[tuple[str, str]]):
        """Delete the relationships between each given pair of nodes.

        Each pair is handled in its own session and write transaction; the
        match ignores relationship direction and type.

        Args:
            edges: List of edges to be deleted, each edge is a (source, target)
                tuple of entity_ids

        Raises:
            RuntimeError: If the driver has not been initialized
            Exception: If there is an error executing the query
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )
        for src, dst in edges:

            async def _unlink(tx: AsyncManagedTransaction):
                label = self._get_workspace_label()
                cypher = f"""
                MATCH (source:`{label}` {{entity_id: $source_entity_id}})-[r]-(target:`{label}` {{entity_id: $target_entity_id}})
                DELETE r
                """
                outcome = await tx.run(
                    cypher, source_entity_id=src, target_entity_id=dst
                )
                logger.debug(f"Deleted edge from '{src}' to '{dst}'")
                await outcome.consume()  # Drain the result inside the transaction

            try:
                async with self._driver.session(database=self._DATABASE) as session:
                    await session.execute_write(_unlink)
            except Exception as e:
                logger.error(f"Error during edge deletion: {str(e)}")
                raise

    async def drop(self) -> dict[str, str]:
        """Drop all data that belongs to the current workspace.

        Detach-deletes every node carrying the workspace label (and therefore
        the relationships attached to those nodes). Nodes of other workspaces
        are not matched by the query.

        Returns:
            dict[str, str]: Operation status and message
            - On success: {"status": "success", "message": "workspace data dropped"}
            - On failure: {"status": "error", "message": "<error details>"}

        Raises:
            RuntimeError: If the driver has not been initialized. Query errors
                are not raised; they are logged and reported in the returned dict.
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )
        # Resolved before the try block so the error handler can always log it,
        # even when opening the session itself fails.
        workspace_label = self._get_workspace_label()
        try:
            async with self._driver.session(database=self._DATABASE) as session:
                query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
                result = await session.run(query)
                await result.consume()
                logger.info(
                    f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
                )
                return {"status": "success", "message": "workspace data dropped"}
        except Exception as e:
            logger.error(
                f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
            )
            return {"status": "error", "message": str(e)}

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Get the total degree (sum of relationships) of two nodes.

        Args:
            src_id: Label of the source node
            tgt_id: Label of the target node

        Returns:
            int: Sum of the degrees of both nodes
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )
        src_degree = await self.node_degree(src_id)
        trg_degree = await self.node_degree(tgt_id)

        # Convert None to 0 for addition
        src_degree = 0 if src_degree is None else src_degree
        trg_degree = 0 if trg_degree is None else trg_degree

        degrees = int(src_degree) + int(trg_degree)
        return degrees

    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
        """Find the workspace nodes whose source_id references any of the given chunks.

        A node matches when one of ``chunk_ids`` appears among the parts of its
        ``source_id`` property split on ``GRAPH_FIELD_SEP``.

        Args:
            chunk_ids: List of chunk IDs to find associated nodes for

        Returns:
            list[dict]: One property dictionary per matching node, each with an
                added ``id`` key holding the node's entity_id. Empty when
                nothing matches.

        Raises:
            RuntimeError: If the driver has not been initialized
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )
        label = self._get_workspace_label()
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            cypher = f"""
            UNWIND $chunk_ids AS chunk_id
            MATCH (n:`{label}`)
            WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
            RETURN DISTINCT n
            """
            cursor = await session.run(cypher, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
            found = []
            async for row in cursor:
                props = dict(row["n"])
                found.append({**props, "id": props.get("entity_id")})
            await cursor.consume()
            return found

    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
        """Find the workspace edges whose source_id references any of the given chunks.

        An edge matches when one of ``chunk_ids`` appears among the parts of its
        ``source_id`` property split on ``GRAPH_FIELD_SEP``. The query orders
        each edge's endpoints so that an edge is reported once regardless of
        the direction it was matched in.

        Args:
            chunk_ids: List of chunk IDs to find associated edges for

        Returns:
            list[dict]: One property dictionary per matching edge, each with
                added ``source`` and ``target`` keys holding the endpoint
                entity_ids. Empty when nothing matches.

        Raises:
            RuntimeError: If the driver has not been initialized
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )
        label = self._get_workspace_label()
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            cypher = f"""
            UNWIND $chunk_ids AS chunk_id
            MATCH (a:`{label}`)-[r]-(b:`{label}`)
            WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
            WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
            // Ensure we only return each unique edge once by ordering the source and target
            WITH a, b, r,
                 CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source,
                 CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target
            RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties
            """
            cursor = await session.run(cypher, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
            found = []
            async for row in cursor:
                # The returned property map is extended in place with the endpoints.
                edge = row["properties"]
                edge.update(source=row["source"], target=row["target"])
                found.append(edge)
            await cursor.consume()
            return found

    async def get_knowledge_graph(
        self,
        node_label: str,
        max_depth: int = 3,
        max_nodes: int = MAX_GRAPH_NODES,
    ) -> KnowledgeGraph:
        """
        Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.

        Args:
            node_label: Label of the starting node, * means all nodes
            max_depth: Maximum depth of the subgraph, Defaults to 3
            max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000

        Returns:
            KnowledgeGraph object containing nodes and edges, with an is_truncated flag
            indicating whether the graph was truncated due to max_nodes limit

        Raises:
            Exception: If there is an error executing the query
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )

        result = KnowledgeGraph()
        seen_nodes = set()
        seen_edges = set()
        workspace_label = self._get_workspace_label()
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            try:
                if node_label == "*":
                    # First check if database has any nodes
                    count_query = "MATCH (n) RETURN count(n) as total"
                    count_result = None
                    total_count = 0
                    try:
                        count_result = await session.run(count_query)
                        count_record = await count_result.single()
                        if count_record:
                            total_count = count_record["total"]
                            if total_count == 0:
                                logger.debug("No nodes found in database")
                                return result
                            if total_count > max_nodes:
                                result.is_truncated = True
                                logger.info(
                                    f"Graph truncated: {total_count} nodes found, limited to {max_nodes}"
                                )
                    finally:
                        if count_result:
                            await count_result.consume()

                    # Run the main query to get nodes with highest degree
                    main_query = f"""
                    MATCH (n:`{workspace_label}`)
                    OPTIONAL MATCH (n)-[r]-()
                    WITH n, COALESCE(count(r), 0) AS degree
                    ORDER BY degree DESC
                    LIMIT $max_nodes
                    WITH collect(n) AS kept_nodes
                    MATCH (a)-[r]-(b)
                    WHERE a IN kept_nodes AND b IN kept_nodes
                    RETURN [node IN kept_nodes | {{node: node}}] AS node_info,
                        collect(DISTINCT r) AS relationships
                    """
                    result_set = None
                    try:
                        result_set = await session.run(
                            main_query, {"max_nodes": max_nodes}
                        )
                        record = await result_set.single()
                        if not record:
                            logger.debug("No record returned from main query")
                            return result
                    finally:
                        if result_set:
                            await result_set.consume()

                else:
                    bfs_query = f"""
                    MATCH (start:`{workspace_label}`)
                    WHERE start.entity_id = $entity_id
                    WITH start
                    CALL {{
                        WITH start
                        MATCH path = (start)-[*0..{max_depth}]-(node)
                        WITH nodes(path) AS path_nodes, relationships(path) AS path_rels
                        UNWIND path_nodes AS n
                        WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists
                        WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels
                        RETURN all_nodes, all_rels
                    }}
                    WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes
                    WITH
                    CASE
                        WHEN total_nodes <= {max_nodes} THEN nodes
                        ELSE nodes[0..{max_nodes}]
                    END AS limited_nodes,
                    relationships,
                    total_nodes,
                    total_nodes > {max_nodes} AS is_truncated
                    RETURN
                    [node IN limited_nodes | {{node: node}}] AS node_info,
                    relationships,
                    total_nodes,
                    is_truncated
                    """
                    result_set = None
                    try:
                        result_set = await session.run(
                            bfs_query,
                            {
                                "entity_id": node_label,
                            },
                        )
                        record = await result_set.single()
                        if not record:
                            logger.debug(f"No nodes found for entity_id: {node_label}")
                            return result

                        # Check if the query indicates truncation
                        if "is_truncated" in record and record["is_truncated"]:
                            result.is_truncated = True
                            logger.info(
                                f"Graph truncated: breadth-first search limited to {max_nodes} nodes"
                            )

                    finally:
                        if result_set:
                            await result_set.consume()

                # Process the record if it exists
                if record and record["node_info"]:
                    for node_info in record["node_info"]:
                        node = node_info["node"]
                        node_id = node.id
                        if node_id not in seen_nodes:
                            seen_nodes.add(node_id)
                            result.nodes.append(
                                KnowledgeGraphNode(
                                    id=f"{node_id}",
                                    labels=[node.get("entity_id")],
                                    properties=dict(node),
                                )
                            )

                    for rel in record["relationships"]:
                        edge_id = rel.id
                        if edge_id not in seen_edges:
                            seen_edges.add(edge_id)
                            start = rel.start_node
                            end = rel.end_node
                            result.edges.append(
                                KnowledgeGraphEdge(
                                    id=f"{edge_id}",
                                    type=rel.type,
                                    source=f"{start.id}",
                                    target=f"{end.id}",
                                    properties=dict(rel),
                                )
                            )

                logger.info(
                    f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
                )

            except Exception as e:
                logger.error(f"Error getting knowledge graph: {str(e)}")
                # Return empty but properly initialized KnowledgeGraph on error
                return KnowledgeGraph()

        return result
